import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
import argparse
import os

# Hex colour palette used for the scatter plots below (indexed by position).
color_defaults = [
    '#1f77b4',  # muted blue
    '#ff7f0e',  # safety orange
    '#2ca02c',  # cooked asparagus green
    '#d62728',  # brick red
    '#e377c2',  # raspberry yogurt pink
    # '#7f7f7f',  # middle gray (disabled)
    '#bcbd22',  # curry yellow-green
    '#9467bd',  # muted purple
    '#8c564b',  # chestnut brown
    '#17becf',  # blue-teal
    '#1f77b4',  # muted blue (same value as index 0)
]

def create_heatmap(xdist, ydist, target, bins=15):
    """Save a log-scaled 2D histogram of the points (xdist, ydist) to ``target``.

    Args:
        xdist: x coordinates of the samples.
        ydist: y coordinates of the samples (same length as ``xdist``).
        target: output file path; the image format follows its extension.
        bins: number of bins per axis, forwarded to ``np.histogram2d``.

    Cell values are raw sample counts per bin (``histogram2d`` is called
    without ``density``). Zero-count bins are masked by ``LogNorm``, so the
    black axes background shows through them. If no bin is populated (e.g.
    empty input), a linear colour scale is used, because ``LogNorm`` cannot
    autoscale without positive values.
    """
    heatmap, xedges, yedges = np.histogram2d(xdist, ydist, bins=bins)

    # Use a fresh figure so limits/ticks/spines left over from earlier plots
    # cannot leak into the heatmap.
    fig, ax = plt.subplots()
    ax.set_facecolor('black')

    norm = LogNorm() if np.any(heatmap > 0) else None

    # histogram2d indexes as [x_bin, y_bin]; transpose so x runs along the
    # columns. origin='lower' makes y increase upwards, and the extent maps
    # pixels onto data coordinates instead of bin indices.
    image = ax.imshow(
        heatmap.T,
        cmap='magma',
        interpolation='nearest',
        norm=norm,
        origin='lower',
        extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]),
        aspect='auto',
    )
    fig.colorbar(image, ax=ax, label='Density')
    fig.savefig(target)
    plt.close(fig)


def _symmetric_ticks(limit):
    """Return tick positions and labels at -limit, -limit/2, 0, limit/2, limit.

    Whole-number positions are labelled without a trailing ``.0`` so the
    labels are consistent, e.g. ``limit=2`` gives ``['-2', '-1', '0', '1', '2']``.
    """
    ticks = [-limit, -limit / 2, 0, limit / 2, limit]
    labels = [str(int(t)) if float(t).is_integer() else str(t) for t in ticks]
    return ticks, labels


def create_plot(args, xdist, ydist, color_sequence, target):
    """Save a scatter plot of the points (xdist, ydist) to ``target``.

    Args:
        args: parsed CLI namespace. Reads ``xlim``/``ylim`` (the axes span
            ``[-lim, lim]`` with five ticks) and ``xlabel``/``ylabel``
            (an axis label is drawn only when non-empty).
        xdist: x coordinates of the points.
        ydist: y coordinates of the points.
        color_sequence: forwarded to ``scatter(c=...)``: a single colour or
            one colour per point.
        target: output file path; the image format follows its extension.
    """
    # Use a fresh figure per call so no state from a previous plot survives.
    fig, ax = plt.subplots()
    ax.spines[['right', 'top']].set_visible(False)
    ax.scatter(xdist, ydist, c=color_sequence)

    ax.tick_params(axis='both', which='major', labelsize=20)
    xticks, xlabels = _symmetric_ticks(args.xlim)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)
    yticks, ylabels = _symmetric_ticks(args.ylim)
    ax.set_yticks(yticks)
    ax.set_yticklabels(ylabels)
    # Set limits after the ticks so they are authoritative.
    ax.set_xlim(-args.xlim, args.xlim)
    ax.set_ylim(-args.ylim, args.ylim)

    # Explicitly enable the grid (pyplot.grid() with no arguments toggles it).
    ax.grid(True)

    if args.ylabel:
        ax.set_ylabel(args.ylabel, fontsize=15)
    if args.xlabel:
        ax.set_xlabel(args.xlabel, fontsize=15)

    fig.savefig(target)
    plt.close(fig)


def plot_achieved_desired():
    """CLI entry point: load achieved/desired samples and write SVG plots.

    Reads ``<--path>/achieved_desired0.npy`` and writes three scatter plots
    and four heatmaps into ``--target``. The ``--target`` directory is created
    if needed. Each plotted subset is truncated to ``--max-points`` samples.
    """
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument('--path', default='./data')
    parser.add_argument('--title', default="small")  # not read in this function
    parser.add_argument('--target', default='./')
    parser.add_argument('--xlim', type=int, default=2)
    parser.add_argument('--ylim', type=float, default=2)
    parser.add_argument('--ylabel', default="")
    parser.add_argument('--xlabel', default="")
    parser.add_argument('--max-points', type=int, default=3000,
                        help='maximum number of samples drawn per plot')
    args = parser.parse_args()
    limit = args.max_points

    # NOTE(review): [:, 0] keeps index 0 along axis 1; the first five columns
    # of the result are read as (success, x_base, y_base, x_achieved,
    # y_achieved) -- confirm against whatever writes this file.
    achieved_desired = np.load(os.path.join(args.path, "achieved_desired0.npy"))[:, 0]
    print(achieved_desired.shape)
    (success, x_distance_base, y_distance_base,
     x_distance_achieved, y_distance_achieved) = (achieved_desired[:, i] for i in range(5))

    succeeded = success == 1
    failed = success == 0
    x_achieved_success = x_distance_achieved[succeeded][:limit]
    y_achieved_success = y_distance_achieved[succeeded][:limit]
    x_achieved_fail = x_distance_achieved[failed][:limit]
    y_achieved_fail = y_distance_achieved[failed][:limit]
    x_distance_base, y_distance_base = x_distance_base[:limit], y_distance_base[:limit]
    x_distance_achieved, y_distance_achieved = x_distance_achieved[:limit], y_distance_achieved[:limit]
    print(int(np.count_nonzero(succeeded)))

    # savefig does not create directories; make sure the target exists.
    if args.target:
        os.makedirs(args.target, exist_ok=True)

    create_plot(args, x_distance_base, y_distance_base, color_defaults[0], target=os.path.join(args.target, "achieved_desired_base.svg"))
    create_plot(args, x_achieved_fail, y_achieved_fail, color_defaults[1], target=os.path.join(args.target, "achieved_desired_failed.svg"))
    # Fixed palette colour (previously the colour of whichever sample sat at index 2).
    create_plot(args, x_achieved_success, y_achieved_success, color_defaults[2], target=os.path.join(args.target, "achieved_desired_filtered.svg"))

    create_heatmap(x_distance_base, y_distance_base, target=os.path.join(args.target, "ad_heat_base.svg"))
    create_heatmap(x_achieved_fail, y_achieved_fail, target=os.path.join(args.target, "ad_heat_failed.svg"))
    create_heatmap(x_distance_achieved, y_distance_achieved, target=os.path.join(args.target, "ad_heat_achieved.svg"))
    create_heatmap(x_achieved_success, y_achieved_success, target=os.path.join(args.target, "ad_heat_success.svg"))

if __name__ == '__main__':
    # Example invocation:
    #   python Utils/plot_achieved_desired.py --xlim 2 --ylim 2 --path ./data/small/
    plot_achieved_desired()